def solve(intervals):
    """Merge overlapping intervals and return them ordered by start.

    Intervals are treated as closed: two intervals that merely touch
    (e.g. ``[1, 4]`` and ``[4, 5]``) are merged into one.

    Args:
        intervals: Iterable of two-element ``[left, right]`` pairs. It is
            not modified by this function.

    Returns:
        A new list of ``[left, right]`` lists, pairwise disjoint and sorted
        by ``left``. An empty input yields an empty list.
    """
    if not intervals:
        return []

    # Work on a sorted copy (largest end first) so the caller's list keeps
    # its original order.
    ordered = sorted(intervals, key=lambda iv: iv[1], reverse=True)

    cur_left, cur_right = ordered[0]
    res = []
    for left, right in ordered[1:]:
        if right < cur_left:
            # Ends are non-increasing, so nothing later can reach the current
            # group either: it is complete.
            res.append([cur_left, cur_right])
            cur_left, cur_right = left, right
        elif left < cur_left:
            # Overlaps the current group; only the left edge can grow because
            # the group's right edge is already the largest end seen so far.
            cur_left = left
        # Otherwise the interval lies entirely inside the current group.

    res.append([cur_left, cur_right])
    # Groups were collected from right to left; present them by start.
    res.sort(key=lambda iv: iv[0])
    return res


if __name__ == "__main__":
    # Small demo: the first two intervals overlap, the other two stand alone.
    intervals = [
        [1, 3],
        [2, 6],
        [8, 10],
        [15, 18],
    ]
    merged = solve(intervals)
    print(merged)
